{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.5.1+cu121\n",
      "use cuda\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import math\n",
    "\n",
    "print(torch.__version__)\n",
    "USE_GPU = True\n",
    "dtype = torch.float32\n",
    "if USE_GPU and torch.cuda.is_available():\n",
    "    device = torch.device('cuda')\n",
    "    print(\"use cuda\")\n",
    "else:\n",
    "    device = torch.device('cpu')\n",
    "    print(\"use cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Setup cell.\n",
    "import time, os, json\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from utils.coco_utils import load_coco_data, sample_coco_minibatch, decode_captions\n",
    "\n",
    "%matplotlib inline\n",
    "plt.rcParams['figure.figsize'] = (10.0, 8.0) # Set default size of plots.\n",
    "plt.rcParams['image.interpolation'] = 'nearest'\n",
    "plt.rcParams['image.cmap'] = 'gray'\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "def rel_error(x, y):\n",
    "    \"\"\" returns relative error \"\"\"\n",
    "    return np.max(np.abs(x - y) / (np.maximum(1e-8, np.abs(x) + np.abs(y))))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "base dir  D:\\workplace\\simple-dnn\\utils\\../datasets/coco_captioning\n",
      "-------------------------------\n",
      "train_captions <class 'numpy.ndarray'> (400135, 17) int32\n",
      "train_image_idxs <class 'numpy.ndarray'> (400135,) int32\n",
      "val_captions <class 'numpy.ndarray'> (195954, 17) int32\n",
      "val_image_idxs <class 'numpy.ndarray'> (195954,) int32\n",
      "train_features <class 'numpy.ndarray'> (82783, 512) float32\n",
      "val_features <class 'numpy.ndarray'> (40504, 512) float32\n",
      "idx_to_word <class 'list'> 1004\n",
      "word_to_idx <class 'dict'> 1004\n",
      "train_urls <class 'numpy.ndarray'> (82783,) <U63\n",
      "val_urls <class 'numpy.ndarray'> (40504,) <U63\n"
     ]
    }
   ],
   "source": [
    "# Load COCO data from disk into a dictionary.\n",
    "data = load_coco_data(pca_features=True)\n",
    "print(\"-------------------------------\")\n",
    "# Print out all the keys and values from the data dictionary.\n",
    "for k, v in data.items():\n",
    "    if type(v) == np.ndarray:\n",
    "        print(k, type(v), v.shape, v.dtype)\n",
    "    else:\n",
    "        print(k, type(v), len(v))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def transformer_temporal_softmax_loss(x, y, mask):\n",
    "    N, T, V = x.shape\n",
    "\n",
    "    x_flat = x.reshape(N * T, V)\n",
    "    y_flat = y.reshape(N * T)\n",
    "    mask_flat = mask.reshape(N * T)\n",
    "    loss = torch.nn.functional.cross_entropy(x_flat,  y_flat, reduction='none')\n",
    "    loss = torch.mul(loss, mask_flat)\n",
    "    loss = torch.mean(loss)\n",
    "\n",
    "    return loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "ename": "RuntimeError",
     "evalue": "The shape of the 2D attn_mask is torch.Size([16, 16]), but should be (1, 1).",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "Cell \u001b[1;32mIn[11], line 86\u001b[0m\n\u001b[0;32m     84\u001b[0m t_captions_out \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mLongTensor(captions_out)\n\u001b[0;32m     85\u001b[0m t_mask \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mLongTensor(mask)\n\u001b[1;32m---> 86\u001b[0m logits \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[43mt_features\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mt_captions_in\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m     88\u001b[0m loss \u001b[38;5;241m=\u001b[39m transformer_temporal_softmax_loss(logits, t_captions_out, t_mask)\n\u001b[0;32m     89\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mloss:\u001b[39m\u001b[38;5;124m\"\u001b[39m, loss)\n",
      "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python311\\site-packages\\torch\\nn\\modules\\module.py:1736\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m   1734\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m   1735\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1736\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python311\\site-packages\\torch\\nn\\modules\\module.py:1747\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m   1742\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m   1743\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m   1744\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m   1745\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m   1746\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1747\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m   1749\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m   1750\u001b[0m called_always_called_hooks \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mset\u001b[39m()\n",
      "Cell \u001b[1;32mIn[11], line 33\u001b[0m, in \u001b[0;36mCaptioningTransformer.forward\u001b[1;34m(self, features, captions)\u001b[0m\n\u001b[0;32m     31\u001b[0m feature_vec \u001b[38;5;241m=\u001b[39m feature_vec\u001b[38;5;241m.\u001b[39munsqueeze(\u001b[38;5;241m1\u001b[39m)\n\u001b[0;32m     32\u001b[0m \u001b[38;5;66;03m# features = self.encoder(features)\u001b[39;00m\n\u001b[1;32m---> 33\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdecoder\u001b[49m\u001b[43m(\u001b[49m\u001b[43mcaption_embed\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfeatures\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43munsqueeze\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtgt_mask\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m     34\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m-----\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m     35\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moutput_linear(output)\n",
      "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python311\\site-packages\\torch\\nn\\modules\\module.py:1736\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m   1734\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m   1735\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1736\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python311\\site-packages\\torch\\nn\\modules\\module.py:1747\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m   1742\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m   1743\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m   1744\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m   1745\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m   1746\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1747\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m   1749\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m   1750\u001b[0m called_always_called_hooks \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mset\u001b[39m()\n",
      "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python311\\site-packages\\torch\\nn\\modules\\transformer.py:602\u001b[0m, in \u001b[0;36mTransformerDecoder.forward\u001b[1;34m(self, tgt, memory, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask, tgt_is_causal, memory_is_causal)\u001b[0m\n\u001b[0;32m    599\u001b[0m tgt_is_causal \u001b[38;5;241m=\u001b[39m _detect_is_causal_mask(tgt_mask, tgt_is_causal, seq_len)\n\u001b[0;32m    601\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m mod \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlayers:\n\u001b[1;32m--> 602\u001b[0m     output \u001b[38;5;241m=\u001b[39m \u001b[43mmod\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m    603\u001b[0m \u001b[43m        \u001b[49m\u001b[43moutput\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m    604\u001b[0m \u001b[43m        \u001b[49m\u001b[43mmemory\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m    605\u001b[0m \u001b[43m        \u001b[49m\u001b[43mtgt_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtgt_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m    606\u001b[0m \u001b[43m        \u001b[49m\u001b[43mmemory_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmemory_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m    607\u001b[0m \u001b[43m        \u001b[49m\u001b[43mtgt_key_padding_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtgt_key_padding_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m    608\u001b[0m \u001b[43m        \u001b[49m\u001b[43mmemory_key_padding_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmemory_key_padding_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m    609\u001b[0m \u001b[43m        \u001b[49m\u001b[43mtgt_is_causal\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtgt_is_causal\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m    610\u001b[0m \u001b[43m        
\u001b[49m\u001b[43mmemory_is_causal\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmemory_is_causal\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m    611\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m    613\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnorm \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m    614\u001b[0m     output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnorm(output)\n",
      "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python311\\site-packages\\torch\\nn\\modules\\module.py:1736\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m   1734\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m   1735\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1736\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python311\\site-packages\\torch\\nn\\modules\\module.py:1747\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m   1742\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m   1743\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m   1744\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m   1745\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m   1746\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1747\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m   1749\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m   1750\u001b[0m called_always_called_hooks \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mset\u001b[39m()\n",
      "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python311\\site-packages\\torch\\nn\\modules\\transformer.py:1087\u001b[0m, in \u001b[0;36mTransformerDecoderLayer.forward\u001b[1;34m(self, tgt, memory, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask, tgt_is_causal, memory_is_causal)\u001b[0m\n\u001b[0;32m   1084\u001b[0m     x \u001b[38;5;241m=\u001b[39m x \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_ff_block(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnorm3(x))\n\u001b[0;32m   1085\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m   1086\u001b[0m     x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnorm1(\n\u001b[1;32m-> 1087\u001b[0m         x \u001b[38;5;241m+\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_sa_block\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtgt_mask\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtgt_key_padding_mask\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtgt_is_causal\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m   1088\u001b[0m     )\n\u001b[0;32m   1089\u001b[0m     x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnorm2(\n\u001b[0;32m   1090\u001b[0m         x\n\u001b[0;32m   1091\u001b[0m         \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_mha_block(\n\u001b[0;32m   1092\u001b[0m             x, memory, memory_mask, memory_key_padding_mask, memory_is_causal\n\u001b[0;32m   1093\u001b[0m         )\n\u001b[0;32m   1094\u001b[0m     )\n\u001b[0;32m   1095\u001b[0m     x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnorm3(x \u001b[38;5;241m+\u001b[39m 
\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_ff_block(x))\n",
      "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python311\\site-packages\\torch\\nn\\modules\\transformer.py:1107\u001b[0m, in \u001b[0;36mTransformerDecoderLayer._sa_block\u001b[1;34m(self, x, attn_mask, key_padding_mask, is_causal)\u001b[0m\n\u001b[0;32m   1100\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_sa_block\u001b[39m(\n\u001b[0;32m   1101\u001b[0m     \u001b[38;5;28mself\u001b[39m,\n\u001b[0;32m   1102\u001b[0m     x: Tensor,\n\u001b[1;32m   (...)\u001b[0m\n\u001b[0;32m   1105\u001b[0m     is_causal: \u001b[38;5;28mbool\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m,\n\u001b[0;32m   1106\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tensor:\n\u001b[1;32m-> 1107\u001b[0m     x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mself_attn\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m   1108\u001b[0m \u001b[43m        \u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1109\u001b[0m \u001b[43m        \u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1110\u001b[0m \u001b[43m        \u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1111\u001b[0m \u001b[43m        \u001b[49m\u001b[43mattn_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattn_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1112\u001b[0m \u001b[43m        \u001b[49m\u001b[43mkey_padding_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mkey_padding_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1113\u001b[0m \u001b[43m        \u001b[49m\u001b[43mis_causal\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mis_causal\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1114\u001b[0m \u001b[43m        
\u001b[49m\u001b[43mneed_weights\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[0;32m   1115\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m[\u001b[38;5;241m0\u001b[39m]\n\u001b[0;32m   1116\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdropout1(x)\n",
      "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python311\\site-packages\\torch\\nn\\modules\\module.py:1736\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m   1734\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m   1735\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1736\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python311\\site-packages\\torch\\nn\\modules\\module.py:1747\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m   1742\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m   1743\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m   1744\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m   1745\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m   1746\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1747\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m   1749\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m   1750\u001b[0m called_always_called_hooks \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mset\u001b[39m()\n",
      "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python311\\site-packages\\torch\\nn\\modules\\activation.py:1368\u001b[0m, in \u001b[0;36mMultiheadAttention.forward\u001b[1;34m(self, query, key, value, key_padding_mask, need_weights, attn_mask, average_attn_weights, is_causal)\u001b[0m\n\u001b[0;32m   1342\u001b[0m     attn_output, attn_output_weights \u001b[38;5;241m=\u001b[39m F\u001b[38;5;241m.\u001b[39mmulti_head_attention_forward(\n\u001b[0;32m   1343\u001b[0m         query,\n\u001b[0;32m   1344\u001b[0m         key,\n\u001b[1;32m   (...)\u001b[0m\n\u001b[0;32m   1365\u001b[0m         is_causal\u001b[38;5;241m=\u001b[39mis_causal,\n\u001b[0;32m   1366\u001b[0m     )\n\u001b[0;32m   1367\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1368\u001b[0m     attn_output, attn_output_weights \u001b[38;5;241m=\u001b[39m \u001b[43mF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmulti_head_attention_forward\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m   1369\u001b[0m \u001b[43m        \u001b[49m\u001b[43mquery\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1370\u001b[0m \u001b[43m        \u001b[49m\u001b[43mkey\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1371\u001b[0m \u001b[43m        \u001b[49m\u001b[43mvalue\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1372\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43membed_dim\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1373\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnum_heads\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1374\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43min_proj_weight\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1375\u001b[0m \u001b[43m        
\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43min_proj_bias\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1376\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbias_k\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1377\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbias_v\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1378\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43madd_zero_attn\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1379\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdropout\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1380\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mout_proj\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1381\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mout_proj\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbias\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1382\u001b[0m \u001b[43m        \u001b[49m\u001b[43mtraining\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtraining\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1383\u001b[0m \u001b[43m        \u001b[49m\u001b[43mkey_padding_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mkey_padding_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1384\u001b[0m \u001b[43m        \u001b[49m\u001b[43mneed_weights\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mneed_weights\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1385\u001b[0m \u001b[43m        
\u001b[49m\u001b[43mattn_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattn_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1386\u001b[0m \u001b[43m        \u001b[49m\u001b[43maverage_attn_weights\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43maverage_attn_weights\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1387\u001b[0m \u001b[43m        \u001b[49m\u001b[43mis_causal\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mis_causal\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1388\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m   1389\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbatch_first \u001b[38;5;129;01mand\u001b[39;00m is_batched:\n\u001b[0;32m   1390\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m attn_output\u001b[38;5;241m.\u001b[39mtranspose(\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m0\u001b[39m), attn_output_weights\n",
      "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python311\\site-packages\\torch\\nn\\functional.py:6131\u001b[0m, in \u001b[0;36mmulti_head_attention_forward\u001b[1;34m(query, key, value, embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias, bias_k, bias_v, add_zero_attn, dropout_p, out_proj_weight, out_proj_bias, training, key_padding_mask, need_weights, attn_mask, use_separate_proj_weight, q_proj_weight, k_proj_weight, v_proj_weight, static_k, static_v, average_attn_weights, is_causal)\u001b[0m\n\u001b[0;32m   6129\u001b[0m     correct_2d_size \u001b[38;5;241m=\u001b[39m (tgt_len, src_len)\n\u001b[0;32m   6130\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m attn_mask\u001b[38;5;241m.\u001b[39mshape \u001b[38;5;241m!=\u001b[39m correct_2d_size:\n\u001b[1;32m-> 6131\u001b[0m         \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\n\u001b[0;32m   6132\u001b[0m             \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mThe shape of the 2D attn_mask is \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mattn_mask\u001b[38;5;241m.\u001b[39mshape\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m, but should be \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mcorrect_2d_size\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m   6133\u001b[0m         )\n\u001b[0;32m   6134\u001b[0m     attn_mask \u001b[38;5;241m=\u001b[39m attn_mask\u001b[38;5;241m.\u001b[39munsqueeze(\u001b[38;5;241m0\u001b[39m)\n\u001b[0;32m   6135\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m attn_mask\u001b[38;5;241m.\u001b[39mdim() \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m3\u001b[39m:\n",
      "\u001b[1;31mRuntimeError\u001b[0m: The shape of the 2D attn_mask is torch.Size([16, 16]), but should be (1, 1)."
     ]
    }
   ],
   "source": [
    "\n",
    "# 定义编码器和解码器的参数\n",
    "d_model = 512\n",
    "nhead = 8\n",
    "batch_size = 1\n",
    "num_encoder_layers = 6\n",
    "num_decoder_layers = 6\n",
    "dim_feedforward = 2048\n",
    "dropout = 0.1\n",
    "learning_rate = 1e-6\n",
    " \n",
    "# Transformer-based image captioning model\n",
    "class CaptioningTransformer(nn.Module):\n",
    "    \"\"\"Transformer decoder that predicts caption tokens conditioned on image features.\n",
    "\n",
    "    All sequence tensors are batch-first: (batch, seq, d_model).\n",
    "\n",
    "    Args:\n",
    "        ntokens: Vocabulary size; sizes both the embedding table and the output layer.\n",
    "        d_model: Width of the token embeddings and of the incoming image features.\n",
    "        nhead: Number of attention heads.\n",
    "        num_encoder_layers: Depth of the encoder stack (built, but not used by forward).\n",
    "        num_decoder_layers: Depth of the decoder stack.\n",
    "        dim_feedforward: Hidden width of each feed-forward sub-layer.\n",
    "        dropout: Dropout probability.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, ntokens, d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward, dropout):\n",
    "        super().__init__()\n",
    "        self.model_type = 'Transformer'\n",
    "        self.src_mask = None\n",
    "        self.visual_projection = nn.Linear(d_model, d_model)\n",
    "        # Size the embedding from ntokens (not a notebook global) so it always\n",
    "        # agrees with output_linear. Token id 0 is treated as padding.\n",
    "        self.embedding = nn.Embedding(ntokens, d_model, padding_idx=0)\n",
    "        self.pos_encoder = PositionalEncoding(d_model, dropout)\n",
    "        # batch_first=True: with the default sequence-first layout a (N, T, D) input is\n",
    "        # read as T sequences of length N and the (T, T) target mask is rejected.\n",
    "        encoder_layers = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, batch_first=True)\n",
    "        self.encoder = nn.TransformerEncoder(encoder_layers, num_encoder_layers)\n",
    "        decoder_layers = nn.TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout, batch_first=True)\n",
    "        self.decoder = nn.TransformerDecoder(decoder_layers, num_decoder_layers)\n",
    "        self.output_linear = nn.Linear(d_model, ntokens)\n",
    "\n",
    "    def forward(self, features, captions):\n",
    "        \"\"\"Return per-token log-probabilities of shape (N, T, ntokens).\n",
    "\n",
    "        Args:\n",
    "            features: Float tensor (N, d_model), one feature vector per image.\n",
    "            captions: Long tensor (N, T) of input token ids.\n",
    "        \"\"\"\n",
    "        tgt_mask = self.generate_square_subsequent_mask(captions.shape[1]).to(features.device)\n",
    "\n",
    "        caption_embed = self.pos_encoder(self.embedding(captions))\n",
    "        # The projected image feature is the decoder memory: a sequence of length 1.\n",
    "        memory = self.visual_projection(features).unsqueeze(1)\n",
    "        output = self.decoder(caption_embed, memory, tgt_mask=tgt_mask)\n",
    "        output = self.output_linear(output)\n",
    "        return F.log_softmax(output, dim=-1)\n",
    "\n",
    "    def generate_square_subsequent_mask(self, sz):\n",
    "        \"\"\"Return a (sz, sz) boolean causal mask.\n",
    "\n",
    "        PyTorch attention treats True as \"not allowed to attend\", so the strictly\n",
    "        upper triangle (future positions) is True and everything else is False.\n",
    "        \"\"\"\n",
    "        return torch.triu(torch.ones((sz, sz), dtype=torch.bool), diagonal=1)\n",
    " \n",
    "# Sinusoidal positional encoding\n",
    "class PositionalEncoding(nn.Module):\n",
    "    \"\"\"Add fixed sinusoidal position information to a batch-first input, then apply dropout.\n",
    "\n",
    "    Args:\n",
    "        d_model: Embedding width (odd values are supported).\n",
    "        dropout: Dropout probability applied after the encoding is added.\n",
    "        max_len: Number of positions that are precomputed.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, d_model, dropout=0.1, max_len=5000):\n",
    "        super().__init__()\n",
    "        self.dropout = nn.Dropout(p=dropout)\n",
    "\n",
    "        position = torch.arange(0, max_len).unsqueeze(1)\n",
    "        div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))\n",
    "        angles = position * div_term  # (max_len, ceil(d_model / 2))\n",
    "\n",
    "        pe = torch.zeros(max_len, d_model)\n",
    "        pe[:, 0::2] = torch.sin(angles)\n",
    "        # For odd d_model there is one fewer odd-indexed column than even-indexed ones,\n",
    "        # so trim the angles before filling the cosine slots.\n",
    "        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])\n",
    "        # Registered as a buffer: moves with the module and is saved, but is not trained.\n",
    "        self.register_buffer('pe', pe.unsqueeze(0))\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return dropout(x + positional encoding) for x of shape (N, T, D).\"\"\"\n",
    "        x = x + self.pe[:, :x.size(1), :x.size(2)]\n",
    "        return self.dropout(x)\n",
    " \n",
    "# Build the model and the optimizer\n",
    "ntokens = len(data['word_to_idx'])  # vocabulary size taken from the dataset\n",
    "model = CaptioningTransformer(\n",
    "    ntokens, d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward, dropout\n",
    ").to(device)\n",
    "optimizer = torch.optim.Adam(model.parameters(), learning_rate)\n",
    "\n",
    "num_epochs = 100\n",
    "num_train = data[\"train_captions\"].shape[0]\n",
    "\n",
    "model.train()  # training mode only needs to be set once\n",
    "for t in range(num_epochs):\n",
    "    # Each iteration trains on one freshly sampled minibatch.\n",
    "    minibatch = sample_coco_minibatch(\n",
    "        data, batch_size=batch_size, split=\"train\"\n",
    "    )\n",
    "    captions, features, urls = minibatch\n",
    "\n",
    "    # Teacher forcing: the input drops the last token, the target drops the first.\n",
    "    captions_in = captions[:, :-1]\n",
    "    captions_out = captions[:, 1:]\n",
    "\n",
    "    # Token id 0 is excluded from the loss (same id as the embedding's padding_idx).\n",
    "    mask = captions_out != 0\n",
    "\n",
    "    # Put every tensor on the same device as the model.\n",
    "    t_features = torch.as_tensor(features, dtype=dtype, device=device)\n",
    "    t_captions_in = torch.as_tensor(captions_in, dtype=torch.long, device=device)\n",
    "    t_captions_out = torch.as_tensor(captions_out, dtype=torch.long, device=device)\n",
    "    t_mask = torch.as_tensor(mask, dtype=torch.long, device=device)\n",
    "\n",
    "    logits = model(t_features, t_captions_in)\n",
    "    loss = transformer_temporal_softmax_loss(logits, t_captions_out, t_mask)\n",
    "\n",
    "    optimizer.zero_grad()\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "    print(\"loss:\", loss.item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "ename": "RuntimeError",
     "evalue": "shape '[10, 256, 64]' is invalid for input of size 153600",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "Cell \u001b[1;32mIn[83], line 5\u001b[0m\n\u001b[0;32m      3\u001b[0m memory \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mrand(\u001b[38;5;241m10\u001b[39m, \u001b[38;5;241m30\u001b[39m, \u001b[38;5;241m512\u001b[39m)\n\u001b[0;32m      4\u001b[0m tgt \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mrand(\u001b[38;5;241m20\u001b[39m, \u001b[38;5;241m32\u001b[39m, \u001b[38;5;241m512\u001b[39m)\n\u001b[1;32m----> 5\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[43mtransformer_decoder\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtgt\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmemory\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m      6\u001b[0m \u001b[38;5;28mprint\u001b[39m(out)\n",
      "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python311\\site-packages\\torch\\nn\\modules\\module.py:1736\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m   1734\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m   1735\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1736\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python311\\site-packages\\torch\\nn\\modules\\module.py:1747\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m   1742\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m   1743\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m   1744\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m   1745\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m   1746\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1747\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m   1749\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m   1750\u001b[0m called_always_called_hooks \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mset\u001b[39m()\n",
      "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python311\\site-packages\\torch\\nn\\modules\\transformer.py:602\u001b[0m, in \u001b[0;36mTransformerDecoder.forward\u001b[1;34m(self, tgt, memory, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask, tgt_is_causal, memory_is_causal)\u001b[0m\n\u001b[0;32m    599\u001b[0m tgt_is_causal \u001b[38;5;241m=\u001b[39m _detect_is_causal_mask(tgt_mask, tgt_is_causal, seq_len)\n\u001b[0;32m    601\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m mod \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlayers:\n\u001b[1;32m--> 602\u001b[0m     output \u001b[38;5;241m=\u001b[39m \u001b[43mmod\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m    603\u001b[0m \u001b[43m        \u001b[49m\u001b[43moutput\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m    604\u001b[0m \u001b[43m        \u001b[49m\u001b[43mmemory\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m    605\u001b[0m \u001b[43m        \u001b[49m\u001b[43mtgt_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtgt_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m    606\u001b[0m \u001b[43m        \u001b[49m\u001b[43mmemory_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmemory_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m    607\u001b[0m \u001b[43m        \u001b[49m\u001b[43mtgt_key_padding_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtgt_key_padding_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m    608\u001b[0m \u001b[43m        \u001b[49m\u001b[43mmemory_key_padding_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmemory_key_padding_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m    609\u001b[0m \u001b[43m        \u001b[49m\u001b[43mtgt_is_causal\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtgt_is_causal\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m    610\u001b[0m \u001b[43m        
\u001b[49m\u001b[43mmemory_is_causal\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmemory_is_causal\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m    611\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m    613\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnorm \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m    614\u001b[0m     output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnorm(output)\n",
      "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python311\\site-packages\\torch\\nn\\modules\\module.py:1736\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m   1734\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m   1735\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1736\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python311\\site-packages\\torch\\nn\\modules\\module.py:1747\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m   1742\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m   1743\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m   1744\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m   1745\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m   1746\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1747\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m   1749\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m   1750\u001b[0m called_always_called_hooks \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mset\u001b[39m()\n",
      "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python311\\site-packages\\torch\\nn\\modules\\transformer.py:1091\u001b[0m, in \u001b[0;36mTransformerDecoderLayer.forward\u001b[1;34m(self, tgt, memory, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask, tgt_is_causal, memory_is_causal)\u001b[0m\n\u001b[0;32m   1085\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m   1086\u001b[0m     x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnorm1(\n\u001b[0;32m   1087\u001b[0m         x \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_sa_block(x, tgt_mask, tgt_key_padding_mask, tgt_is_causal)\n\u001b[0;32m   1088\u001b[0m     )\n\u001b[0;32m   1089\u001b[0m     x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnorm2(\n\u001b[0;32m   1090\u001b[0m         x\n\u001b[1;32m-> 1091\u001b[0m         \u001b[38;5;241m+\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_mha_block\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m   1092\u001b[0m \u001b[43m            \u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmemory\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmemory_mask\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmemory_key_padding_mask\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmemory_is_causal\u001b[49m\n\u001b[0;32m   1093\u001b[0m \u001b[43m        \u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m   1094\u001b[0m     )\n\u001b[0;32m   1095\u001b[0m     x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnorm3(x \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_ff_block(x))\n\u001b[0;32m   1097\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m x\n",
      "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python311\\site-packages\\torch\\nn\\modules\\transformer.py:1127\u001b[0m, in \u001b[0;36mTransformerDecoderLayer._mha_block\u001b[1;34m(self, x, mem, attn_mask, key_padding_mask, is_causal)\u001b[0m\n\u001b[0;32m   1119\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_mha_block\u001b[39m(\n\u001b[0;32m   1120\u001b[0m     \u001b[38;5;28mself\u001b[39m,\n\u001b[0;32m   1121\u001b[0m     x: Tensor,\n\u001b[1;32m   (...)\u001b[0m\n\u001b[0;32m   1125\u001b[0m     is_causal: \u001b[38;5;28mbool\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m,\n\u001b[0;32m   1126\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tensor:\n\u001b[1;32m-> 1127\u001b[0m     x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmultihead_attn\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m   1128\u001b[0m \u001b[43m        \u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1129\u001b[0m \u001b[43m        \u001b[49m\u001b[43mmem\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1130\u001b[0m \u001b[43m        \u001b[49m\u001b[43mmem\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1131\u001b[0m \u001b[43m        \u001b[49m\u001b[43mattn_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattn_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1132\u001b[0m \u001b[43m        \u001b[49m\u001b[43mkey_padding_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mkey_padding_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1133\u001b[0m \u001b[43m        \u001b[49m\u001b[43mis_causal\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mis_causal\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1134\u001b[0m \u001b[43m        
\u001b[49m\u001b[43mneed_weights\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[0;32m   1135\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m[\u001b[38;5;241m0\u001b[39m]\n\u001b[0;32m   1136\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdropout2(x)\n",
      "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python311\\site-packages\\torch\\nn\\modules\\module.py:1736\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m   1734\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m   1735\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1736\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python311\\site-packages\\torch\\nn\\modules\\module.py:1747\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m   1742\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m   1743\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m   1744\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m   1745\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m   1746\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1747\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m   1749\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m   1750\u001b[0m called_always_called_hooks \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mset\u001b[39m()\n",
      "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python311\\site-packages\\torch\\nn\\modules\\activation.py:1368\u001b[0m, in \u001b[0;36mMultiheadAttention.forward\u001b[1;34m(self, query, key, value, key_padding_mask, need_weights, attn_mask, average_attn_weights, is_causal)\u001b[0m\n\u001b[0;32m   1342\u001b[0m     attn_output, attn_output_weights \u001b[38;5;241m=\u001b[39m F\u001b[38;5;241m.\u001b[39mmulti_head_attention_forward(\n\u001b[0;32m   1343\u001b[0m         query,\n\u001b[0;32m   1344\u001b[0m         key,\n\u001b[1;32m   (...)\u001b[0m\n\u001b[0;32m   1365\u001b[0m         is_causal\u001b[38;5;241m=\u001b[39mis_causal,\n\u001b[0;32m   1366\u001b[0m     )\n\u001b[0;32m   1367\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1368\u001b[0m     attn_output, attn_output_weights \u001b[38;5;241m=\u001b[39m \u001b[43mF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmulti_head_attention_forward\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m   1369\u001b[0m \u001b[43m        \u001b[49m\u001b[43mquery\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1370\u001b[0m \u001b[43m        \u001b[49m\u001b[43mkey\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1371\u001b[0m \u001b[43m        \u001b[49m\u001b[43mvalue\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1372\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43membed_dim\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1373\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnum_heads\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1374\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43min_proj_weight\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1375\u001b[0m \u001b[43m        
\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43min_proj_bias\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1376\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbias_k\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1377\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbias_v\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1378\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43madd_zero_attn\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1379\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdropout\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1380\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mout_proj\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1381\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mout_proj\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbias\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1382\u001b[0m \u001b[43m        \u001b[49m\u001b[43mtraining\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtraining\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1383\u001b[0m \u001b[43m        \u001b[49m\u001b[43mkey_padding_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mkey_padding_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1384\u001b[0m \u001b[43m        \u001b[49m\u001b[43mneed_weights\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mneed_weights\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1385\u001b[0m \u001b[43m        
\u001b[49m\u001b[43mattn_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattn_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1386\u001b[0m \u001b[43m        \u001b[49m\u001b[43maverage_attn_weights\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43maverage_attn_weights\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1387\u001b[0m \u001b[43m        \u001b[49m\u001b[43mis_causal\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mis_causal\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1388\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m   1389\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbatch_first \u001b[38;5;129;01mand\u001b[39;00m is_batched:\n\u001b[0;32m   1390\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m attn_output\u001b[38;5;241m.\u001b[39mtranspose(\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m0\u001b[39m), attn_output_weights\n",
      "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python311\\site-packages\\torch\\nn\\functional.py:6165\u001b[0m, in \u001b[0;36mmulti_head_attention_forward\u001b[1;34m(query, key, value, embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias, bias_k, bias_v, add_zero_attn, dropout_p, out_proj_weight, out_proj_bias, training, key_padding_mask, need_weights, attn_mask, use_separate_proj_weight, q_proj_weight, k_proj_weight, v_proj_weight, static_k, static_v, average_attn_weights, is_causal)\u001b[0m\n\u001b[0;32m   6163\u001b[0m q \u001b[38;5;241m=\u001b[39m q\u001b[38;5;241m.\u001b[39mview(tgt_len, bsz \u001b[38;5;241m*\u001b[39m num_heads, head_dim)\u001b[38;5;241m.\u001b[39mtranspose(\u001b[38;5;241m0\u001b[39m, \u001b[38;5;241m1\u001b[39m)\n\u001b[0;32m   6164\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m static_k \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m-> 6165\u001b[0m     k \u001b[38;5;241m=\u001b[39m \u001b[43mk\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mview\u001b[49m\u001b[43m(\u001b[49m\u001b[43mk\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mshape\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbsz\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mnum_heads\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mhead_dim\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mtranspose(\u001b[38;5;241m0\u001b[39m, \u001b[38;5;241m1\u001b[39m)\n\u001b[0;32m   6166\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m   6167\u001b[0m     \u001b[38;5;66;03m# TODO finish disentangling control flow so we don't do in-projections when statics are passed\u001b[39;00m\n\u001b[0;32m   6168\u001b[0m     \u001b[38;5;28;01massert\u001b[39;00m (\n\u001b[0;32m   6169\u001b[0m         
static_k\u001b[38;5;241m.\u001b[39msize(\u001b[38;5;241m0\u001b[39m) \u001b[38;5;241m==\u001b[39m bsz \u001b[38;5;241m*\u001b[39m num_heads\n\u001b[0;32m   6170\u001b[0m     ), \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mexpecting static_k.size(0) of \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mbsz\u001b[38;5;250m \u001b[39m\u001b[38;5;241m*\u001b[39m\u001b[38;5;250m \u001b[39mnum_heads\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m, but got \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mstatic_k\u001b[38;5;241m.\u001b[39msize(\u001b[38;5;241m0\u001b[39m)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n",
      "\u001b[1;31mRuntimeError\u001b[0m: shape '[10, 256, 64]' is invalid for input of size 153600"
     ]
    }
   ],
   "source": [
    "# Shape check for a stand-alone nn.TransformerDecoder.\n",
    "# The default layout is sequence-first: (seq_len, batch, d_model).\n",
    "decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)\n",
    "transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)\n",
    "memory = torch.rand(10, 32, 512)  # (source_len, batch, d_model)\n",
    "tgt = torch.rand(20, 32, 512)  # (target_len, batch, d_model); batch must match memory\n",
    "out = transformer_decoder(tgt, memory)\n",
    "assert out.shape == tgt.shape, \"decoder output should have the same shape as tgt\"\n",
    "print(out.shape)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
